In the blog post "Why is Adam's Update RMS 0.2?", we estimated the Update RMS of Adam using a mean-field approximation. Shortly afterwards, reader @EIFY pointed out that similar results had already appeared in the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks". On reading it, the author found that it contains not only an estimate of Update RMS but also an estimate of Weight RMS.
In other words, the RMS of the weights in a model trained with AdamW can be estimated asymptotically in advance. Is this conclusion surprising? The author was certainly surprised on first encountering it. Intuitively, the weight norm is something the model learns from the training data, yet the result tells us it is largely determined by the optimizer's hyperparameters, which seems rather counter-intuitive.
In this article, we will use the mean-field approximation method to reproduce the asymptotic estimation of Weight RMS.
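Moving-Average Perspective
First, let us review the update rule of AdamW:
$$\text{AdamW}:=\left\{\begin{aligned}
&\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + (1 - \beta_1) \boldsymbol{g}_t\\
&\boldsymbol{v}_t = \beta_2 \boldsymbol{v}_{t-1} + (1 - \beta_2) \boldsymbol{g}_t^2\\
&\hat{\boldsymbol{m}}_t = \boldsymbol{m}_t\left/\left(1 - \beta_1^t\right)\right.\\
&\hat{\boldsymbol{v}}_t = \boldsymbol{v}_t\left/\left(1 - \beta_2^t\right)\right.\\
&\boldsymbol{u}_t =\hat{\boldsymbol{m}}_t\left/\left(\sqrt{\hat{\boldsymbol{v}}_t} + \epsilon\right)\right.\\
&\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t + \lambda_t \boldsymbol{\theta}_{t-1})
\end{aligned}\right.\tag{1}$$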
Note that bold symbols here denote vectors in $\mathbb{R}^d$, and vector multiplication/division (including squaring and square roots) refer to element-wise Hadamard product/quotient.
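As in "Why is Adam's Update RMS 0.2?", we take $t$ large enough that $\beta_1^t,\beta_2^t\to 0$ (so bias correction can be dropped) and let $\epsilon\to 0$, so that $\boldsymbol{u}_t=\boldsymbol{m}_t/\sqrt{\boldsymbol{v}_t}$. For now, we also take $\eta_t,\lambda_t$ to be constant, drop their subscripts, and write $\beta_3 = 1-\eta\lambda$. The weight update then becomes:
$$\boldsymbol{\theta}_t = (1 - \eta\lambda)\boldsymbol{\theta}_{t-1} - \eta\boldsymbol{u}_t = \beta_3\boldsymbol{\theta}_{t-1} + (1 - \beta_3)\left(-\frac{\boldsymbol{u}_t}{\lambda}\right)\tag{2}$$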
This equation indicates that Weight Decay can be understood from the perspective of exponential moving average (EMA) of update quantities. This is a meaningful perspective shift and forms the basis for works such as "How to set AdamW's weight decay as you scale model and dataset size" and "Power Lines: Scaling Laws for Weight Decay and Batch Size in LLM Pre-training".
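Here is a minimal numerical check of the moving-average reading. It is a sketch: `eta` and `lam` take arbitrary illustrative values, and a random vector stands in for the update $\boldsymbol{u}_t$. It confirms that one AdamW weight step coincides with one EMA step of $-\boldsymbol{u}_t/\lambda$ at rate $\beta_3 = 1-\eta\lambda$:

```python
import numpy as np

# Illustrative constants (assumptions, not values from the post).
rng = np.random.default_rng(0)
eta, lam = 1e-3, 0.1
beta3 = 1 - eta * lam

theta_prev = rng.standard_normal(5)
u = rng.standard_normal(5)  # placeholder for the Adam update m_t / sqrt(v_t)

# AdamW weight step: theta_t = theta_{t-1} - eta * (u_t + lam * theta_{t-1})
adamw_step = theta_prev - eta * (u + lam * theta_prev)
# EMA form: theta_t = beta3 * theta_{t-1} + (1 - beta3) * (-u_t / lam)
ema_step = beta3 * theta_prev + (1 - beta3) * (-u / lam)

print(np.max(np.abs(adamw_step - ema_step)))  # only floating-point round-off remains
```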
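Weighted Average
Based on Equation (2), we can expand $\boldsymbol{\theta}_t$ as a weighted average:
$$\boldsymbol{\theta}_t = \beta_3^t\boldsymbol{\theta}_0 + (1 - \beta_3)\sum_{i=1}^t \beta_3^{t-i}\left(-\frac{\boldsymbol{u}_i}{\lambda}\right)\tag{3}$$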
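Similarly, $\boldsymbol{m}_t$ and $\boldsymbol{v}_t$ can be expanded as:
$$\boldsymbol{m}_t = (1 - \beta_1)\sum_{i=1}^t \beta_1^{t-i}\boldsymbol{g}_i,\qquad \boldsymbol{v}_t = (1 - \beta_2)\sum_{i=1}^t \beta_2^{t-i}\boldsymbol{g}_i^2\tag{4}$$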
A minor detail: we retain $\boldsymbol{\theta}_0$ in the expression for $\boldsymbol{\theta}_t$, but we do not retain $\boldsymbol{m}_0$ and $\boldsymbol{v}_0$ in the expressions for $\boldsymbol{m}_t$ and $\boldsymbol{v}_t$. The reasons are twofold: 1) $\boldsymbol{m}$ and $\boldsymbol{v}$ are typically initialized to zero; 2) even if they are not initialized to zero, the corresponding $\beta_1^t$ and $\beta_2^t$ will be sufficiently close to zero, making the initialization effect negligible.
However, $\boldsymbol{\theta}$ represents the model weights, whose initialization is usually not zero, and $\beta_3$ is often very close to 1. Over the whole training run $\beta_3^t$ may never get close to zero, so we keep the $\beta_3^t\boldsymbol{\theta}_0$ term explicitly and drop it only where that is justified.
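As a sanity check on the unrolled form, the sketch below runs the recursion $\boldsymbol{\theta}_t=\beta_3\boldsymbol{\theta}_{t-1}-\eta\boldsymbol{u}_t$ for a few hundred steps and compares the result with the closed form $\beta_3^t\boldsymbol{\theta}_0+(1-\beta_3)\sum_{i=1}^t\beta_3^{t-i}(-\boldsymbol{u}_i/\lambda)$. The $\boldsymbol{u}_i$ are random placeholders, an assumption made for illustration. The identity holds for any sequence of updates.

```python
import numpy as np

# Illustrative constants; us[i-1] plays the role of u_i (random placeholders).
rng = np.random.default_rng(1)
eta, lam, t, d = 1e-2, 0.5, 200, 4
beta3 = 1 - eta * lam

theta0 = rng.standard_normal(d)
us = rng.standard_normal((t, d))

# Recursive form: theta_t = beta3 * theta_{t-1} - eta * u_t
theta = theta0.copy()
for u in us:
    theta = beta3 * theta - eta * u

# Closed form: beta3^t theta_0 + (1 - beta3) * sum_i beta3^(t-i) * (-u_i / lam)
weights = beta3 ** np.arange(t - 1, -1, -1)  # beta3^(t-i) for i = 1..t
closed = beta3**t * theta0 + (1 - beta3) * (weights[:, None] * (-us / lam)).sum(axis=0)

print(np.max(np.abs(theta - closed)))
```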
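Quick Estimation
Our task is to estimate Weight RMS, i.e., $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$, defined as the root mean square of the components:
$$\Vert\boldsymbol{\theta}\Vert_{RMS} = \sqrt{\frac{1}{d}\sum_{i=1}^d \theta_i^2} = \frac{\Vert\boldsymbol{\theta}\Vert}{\sqrt{d}}\tag{5}$$
RMS differs from the Euclidean norm only by the factor $1/\sqrt{d}$, so most properties of the norm carry over to RMS. There is a quick but not entirely accurate derivation for $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$: take $\Vert\cdot\Vert_{RMS}^2$ of both sides of Equation (2), which gives
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 = \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + \eta^2\Vert\boldsymbol{u}_t\Vert_{RMS}^2 - \frac{2\eta\beta_3}{d}\,\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t\tag{6}$$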
If $\boldsymbol{\theta}_{t-1}$ and $\boldsymbol{u}_t$ are nearly orthogonal, then $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t\approx 0$. This is usually a good approximation in high-dimensional spaces (see "Angle Distribution Between Two Random Vectors in n-Dimensional Space"). We have also already computed $\Vert\boldsymbol{u}_t\Vert_{RMS}$: for zero-mean gradients it is approximately $\sqrt{\frac{1-\beta_1}{1+\beta_1}}$. Finally, at steady state $\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2=\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2$, which leads to:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta^2\Vert\boldsymbol{u}_t\Vert_{RMS}^2}{1 - \beta_3^2}} \approx \sqrt{\frac{\eta}{2\lambda}}\Vert\boldsymbol{u}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}\cdot\frac{1-\beta_1}{1+\beta_1}}\tag{7}$$
The second approximation uses $\beta_3\approx 1$, namely $1-\beta_3^2\approx 2(1-\beta_3)=2\eta\lambda$. The final result may carry some error, because $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t\approx 0$ is not entirely accurate. The scaling $\Vert\boldsymbol{\theta}_t\Vert_{RMS}\propto \sqrt{\eta/\lambda}$, however, is correct. Similar derivations appear in "Why Gradients Rapidly Increase Near the End of Training".
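The quick estimate is easy to evaluate directly. The sketch below uses illustrative hyperparameters and assumes zero-mean gradients. It computes the steady state of the squared-RMS recursion with the cross term dropped, both with and without the $\beta_3\approx 1$ shortcut, and checks the $\sqrt{\eta/\lambda}$ scaling. The function names are hypothetical.

```python
import math

def quick_weight_rms(eta, lam, beta1):
    """Steady state of the squared-RMS recursion with the cross term dropped.

    No beta3 ~ 1 shortcut: ||theta||^2 = eta^2 ||u||^2 / (1 - beta3^2).
    Assumes zero-mean gradients, so ||u||_RMS ~ sqrt((1 - beta1) / (1 + beta1)).
    """
    beta3 = 1 - eta * lam
    update_rms = math.sqrt((1 - beta1) / (1 + beta1))
    return eta * update_rms / math.sqrt(1 - beta3**2)

def quick_weight_rms_approx(eta, lam, beta1):
    """The same estimate after using 1 - beta3^2 ~ 2 * eta * lam (beta3 ~ 1)."""
    return math.sqrt(eta / (2 * lam) * (1 - beta1) / (1 + beta1))

exact = quick_weight_rms(1e-3, 0.1, 0.9)
approx = quick_weight_rms_approx(1e-3, 0.1, 0.9)
# Doubling eta scales the estimate by sqrt(2): the sqrt(eta / lam) law.
scale = quick_weight_rms_approx(2e-3, 0.1, 0.9) / approx
print(exact, approx, scale)
```

Because $\eta\lambda$ is tiny in practice, the two variants differ only in the fifth significant digit. The $\beta_3\approx 1$ shortcut costs essentially nothing.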
Better Approximation
In many cases it is enough to know that $\Vert\boldsymbol{\theta}_t\Vert_{RMS}\propto \sqrt{\eta/\lambda}$, which is a fairly general conclusion. For readers who want a more accurate result, the mean-field method gives a better approximation. The computation is more involved, but it yields clearer insight.
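Step One
Starting from Equation (3), the summation term is itself a weighted average. So we first apply the mean-field approximation, replacing the weighted average of ratios by the ratio of weighted averages:
$$(1-\beta_3)\sum_{i=1}^t \beta_3^{t-i}\boldsymbol{u}_i = (1-\beta_3)\sum_{i=1}^t \beta_3^{t-i}\frac{\boldsymbol{m}_i}{\sqrt{\boldsymbol{v}_i}} \approx (1-\beta_3^t)\frac{\bar{\boldsymbol{m}}_t}{\sqrt{\bar{\boldsymbol{v}}_t}} \triangleq (1-\beta_3^t)\bar{\boldsymbol{u}}_t\tag{8}$$
Here $\bar{\boldsymbol{m}}_t = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t\beta_3^{t-i}\boldsymbol{m}_i$ and $\bar{\boldsymbol{v}}_t = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t\beta_3^{t-i}\boldsymbol{v}_i$. The factor $\frac{1-\beta_3}{1-\beta_3^t}$ normalizes the weights to sum to 1.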
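Now return to Equation (3). Since $\boldsymbol{\theta}_0$ is a randomly initialized vector, we can assume it is approximately orthogonal to $\bar{\boldsymbol{u}}_t$, which yields:
$$\boldsymbol{\theta}_t \approx \beta_3^t\boldsymbol{\theta}_0 - \frac{1-\beta_3^t}{\lambda}\bar{\boldsymbol{u}}_t\quad\Rightarrow\quad \Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \frac{(1-\beta_3^t)^2}{\lambda^2}\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2\tag{9}$$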
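Now we need $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$. As in the earlier post, we assume the $\boldsymbol{g}_j$ are independently and identically distributed as $\mathcal{N}(\boldsymbol{\mu},\boldsymbol{\sigma}^2)$, and compute, again in mean-field fashion:
$$\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \frac{\mathbb{E}[\bar{\boldsymbol{m}}_t^2]}{\mathbb{E}[\bar{\boldsymbol{v}}_t]}\tag{10}$$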
Finally, averaging the components of $\mathbb{E}[\bar{\boldsymbol{u}}_t^2]$ yields an approximation for $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$.
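Step Two
Combining with Equation (4) and swapping the order of summation, we obtain:
$$\begin{aligned}
\bar{\boldsymbol{m}}_t &= \frac{(1-\beta_3)(1-\beta_1)}{1-\beta_3^t}\sum_{i=1}^t\beta_3^{t-i}\sum_{j=1}^i\beta_1^{i-j}\boldsymbol{g}_j = \frac{(1-\beta_3)(1-\beta_1)}{(1-\beta_3^t)(\beta_3-\beta_1)}\sum_{j=1}^t\left(\beta_3^{t-j+1}-\beta_1^{t-j+1}\right)\boldsymbol{g}_j\\
\bar{\boldsymbol{v}}_t &= \frac{(1-\beta_3)(1-\beta_2)}{1-\beta_3^t}\sum_{i=1}^t\beta_3^{t-i}\sum_{j=1}^i\beta_2^{i-j}\boldsymbol{g}_j^2 = \frac{(1-\beta_3)(1-\beta_2)}{(1-\beta_3^t)(\beta_3-\beta_2)}\sum_{j=1}^t\left(\beta_3^{t-j+1}-\beta_2^{t-j+1}\right)\boldsymbol{g}_j^2
\end{aligned}\tag{11}$$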
The double summation can be simplified by hand or with the help of an assistant such as Kimi. The result shows that $\bar{\boldsymbol{m}}_t,\bar{\boldsymbol{v}}_t$ are weighted averages of the gradients and squared gradients, respectively. Computing $\Vert \bar{\boldsymbol{u}}_t\Vert_{RMS}^2$ is therefore essentially the same as computing $\Vert \boldsymbol{u}_t\Vert_{RMS}^2$ in "Why is Adam's Update RMS 0.2?", only with different weighting coefficients.
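After swapping the order of summation, the coefficient of each $\boldsymbol{g}_j$ in $\bar{\boldsymbol{m}}_t$ contains the inner geometric sum $\sum_{i=j}^t\beta_3^{t-i}\beta_1^{i-j}=\frac{\beta_3^{t-j+1}-\beta_1^{t-j+1}}{\beta_3-\beta_1}$. The $\bar{\boldsymbol{v}}_t$ case is identical with $\beta_2$ in place of $\beta_1$. Below is a short numerical check of this identity, using illustrative values of $\beta_1,\beta_3$ and hypothetical helper names:

```python
def inner_sum_direct(beta3, beta1, t, j):
    # sum_{i=j}^{t} beta3^(t-i) * beta1^(i-j), evaluated term by term
    return sum(beta3 ** (t - i) * beta1 ** (i - j) for i in range(j, t + 1))

def inner_sum_closed(beta3, beta1, t, j):
    # Geometric-series closed form; requires beta3 != beta1
    n = t - j + 1
    return (beta3**n - beta1**n) / (beta3 - beta1)

beta1, beta3 = 0.9, 1 - 1e-4
for t, j in [(1, 1), (10, 3), (500, 1), (500, 250)]:
    print(t, j, inner_sum_direct(beta3, beta1, t, j), inner_sum_closed(beta3, beta1, t, j))
```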
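Step Three
First, we compute the denominator:
$$\begin{aligned}
\mathbb{E}[\bar{\boldsymbol{v}}_t] &= \frac{(1-\beta_3)(1-\beta_2)}{(1-\beta_3^t)(\beta_3-\beta_2)}\sum_{j=1}^t\left(\beta_3^{t-j+1}-\beta_2^{t-j+1}\right)(\boldsymbol{\mu}^2+\boldsymbol{\sigma}^2)\\
&= \frac{(\beta_3-\beta_2) - (1-\beta_2)\beta_3^{t+1} + (1-\beta_3)\beta_2^{t+1}}{(1-\beta_3^t)(\beta_3-\beta_2)}(\boldsymbol{\mu}^2+\boldsymbol{\sigma}^2)\\
&\approx \frac{1-\beta_3^{t+1}}{1-\beta_3^t}(\boldsymbol{\mu}^2+\boldsymbol{\sigma}^2) \approx \boldsymbol{\mu}^2+\boldsymbol{\sigma}^2
\end{aligned}\tag{12}$$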
The last approximations hold because, in practical training, $\beta_3$ is very close to 1 and $\beta_2^{t+1}$ is very close to 0, whereas $\beta_3^{t+1}$ need not be close to either. So we first set $\beta_2^{t+1}$ to zero and simplify. We then set the standalone factors of $\beta_3$ (those not raised to the power $t+1$) to 1. Finally, we use $\beta_3^{t+1}\approx \beta_3^t$.
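To see how good this approximation of the denominator is, one can sum the exact weights numerically. In this sketch the function names are hypothetical, and $\beta_3=1-10^{-4}$ corresponds, e.g., to $\eta=10^{-3}$, $\lambda=0.1$. The exact coefficient multiplying $\boldsymbol{\mu}^2+\boldsymbol{\sigma}^2$ approaches 1 once $t$ is comparable to $1/(1-\beta_3)$. For small $t$ (say $t=1000$) it sits a couple of percent below 1.

```python
def v_bar_coefficient(beta2, beta3, t):
    """Exact total weight on g_j^2 in v_bar_t, so E[v_bar_t] = coeff * (mu^2 + sigma^2)."""
    pre = (1 - beta2) * (1 - beta3) / ((1 - beta3**t) * (beta3 - beta2))
    return pre * sum(beta3 ** (t - j + 1) - beta2 ** (t - j + 1) for j in range(1, t + 1))

def v_bar_coefficient_approx(beta3, t):
    """After dropping beta2^(t+1) and setting the standalone beta3 factors to 1."""
    return (1 - beta3 ** (t + 1)) / (1 - beta3**t)

beta2, beta3 = 0.95, 1 - 1e-4
for t in [1000, 20000, 100000]:
    print(t, v_bar_coefficient(beta2, beta3, t), v_bar_coefficient_approx(beta3, t))
```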
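Step Four
Next, $\mathbb{E}[\bar{\boldsymbol{m}}_t^2] = \mathbb{E}[\bar{\boldsymbol{m}}_t]^2 + \mathbb{V}ar[\bar{\boldsymbol{m}}_t]$. $\mathbb{E}[\bar{\boldsymbol{m}}_t]$ is computed in the same way as $\mathbb{E}[\bar{\boldsymbol{v}}_t]$ and, to the same approximation, equals $\boldsymbol{\mu}$. For $\mathbb{V}ar[\bar{\boldsymbol{m}}_t]$, we use the additivity of variance over the independent $\boldsymbol{g}_j$:
$$\begin{aligned}
\mathbb{V}ar[\bar{\boldsymbol{m}}_t] &= \frac{(1-\beta_3)^2(1-\beta_1)^2}{(1-\beta_3^t)^2(\beta_3-\beta_1)^2}\sum_{j=1}^t\left(\beta_3^{t-j+1}-\beta_1^{t-j+1}\right)^2\boldsymbol{\sigma}^2\\
&\approx \frac{(1-\beta_3)^2}{(1-\beta_3^t)^2}\cdot\frac{1-\beta_3^{2t}}{1-\beta_3^2}\boldsymbol{\sigma}^2 \approx \frac{(1-\beta_3)(1+\beta_3^t)}{2(1-\beta_3^t)}\boldsymbol{\sigma}^2
\end{aligned}\tag{13}$$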
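The approximation follows the same reasoning as above. After expanding the square, only the $\sum_j\beta_3^{2(t-j+1)}$ term grows like $1/(1-\beta_3)$. The cross term and the pure $\beta_1$ term stay bounded as $\beta_3\to 1$ and can be dropped. The remaining standalone factors of $\beta_3$ are then set to 1.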
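The variance coefficient can be checked numerically in the same way, against its leading-order form $\frac{(1-\beta_3)(1+\beta_3^t)}{2(1-\beta_3^t)}$. This uses the same i.i.d. Gaussian-gradient assumption. Below is a sketch with illustrative $\beta_1=0.9$, $\beta_3=1-10^{-4}$ and hypothetical helper names:

```python
def var_m_bar_coefficient(beta1, beta3, t):
    """Exact factor c with Var[m_bar_t] = c * sigma^2 for i.i.d. gradients."""
    pre = ((1 - beta3) * (1 - beta1)) ** 2 / ((1 - beta3**t) * (beta3 - beta1)) ** 2
    return pre * sum((beta3 ** (t - j + 1) - beta1 ** (t - j + 1)) ** 2 for j in range(1, t + 1))

def var_m_bar_coefficient_approx(beta3, t):
    """Leading-order form (1 - beta3)(1 + beta3^t) / (2 (1 - beta3^t))."""
    return (1 - beta3) * (1 + beta3**t) / (2 * (1 - beta3**t))

beta1, beta3 = 0.9, 1 - 1e-4
for t in [20000, 100000]:
    print(t, var_m_bar_coefficient(beta1, beta3, t), var_m_bar_coefficient_approx(beta3, t))
```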
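Step Five
Substituting the previous results, we obtain:
$$\mathbb{E}[\bar{\boldsymbol{u}}_t^2] \approx \frac{\boldsymbol{\mu}^2 + \frac{(1-\beta_3)(1+\beta_3^t)}{2(1-\beta_3^t)}\boldsymbol{\sigma}^2}{\boldsymbol{\mu}^2+\boldsymbol{\sigma}^2}\tag{14}$$
Thus:
$$\Vert\bar{\boldsymbol{u}}_t\Vert_{RMS}^2 \approx \frac{\Vert\boldsymbol{\mu}\Vert^2 + \frac{(1-\beta_3)(1+\beta_3^t)}{2(1-\beta_3^t)}\Vert\boldsymbol{\sigma}\Vert^2}{\Vert\boldsymbol{\mu}\Vert^2+\Vert\boldsymbol{\sigma}\Vert^2}\tag{15}$$
Finally:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \frac{(1-\beta_3^t)^2}{\lambda^2}\cdot\frac{\Vert\boldsymbol{\mu}\Vert^2 + \frac{(1-\beta_3)(1+\beta_3^t)}{2(1-\beta_3^t)}\Vert\boldsymbol{\sigma}\Vert^2}{\Vert\boldsymbol{\mu}\Vert^2+\Vert\boldsymbol{\sigma}\Vert^2}\tag{16}$$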
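Result Analysis
Equation (16) looks complicated, so let us examine a few special cases. First, consider $\boldsymbol{\mu}=\boldsymbol{0}$:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + (1-\beta_3^{2t})\frac{\eta}{2\lambda}\tag{17}$$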
In particular, as $t\to\infty$, or if $\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2$ is initialized to $\eta/2\lambda$, we have:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}}\tag{18}$$
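This matches the result in the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks". That is expected, because the paper's setting is exactly this one: the steady state of a zero-mean random walk. If we instead take the limit $\lambda\to 0$ rather than $t\to\infty$, then using $\beta_3^{2t}=(1-\eta\lambda)^{2t}\approx 1-2\eta\lambda t$, Equation (17) gives:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \eta^2 t\tag{19}$$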
This indicates that without Weight Decay, $\Vert\boldsymbol{\theta}_t\Vert_{RMS}$ grows roughly like $\eta\sqrt{t}$. It follows that without Weight Decay, Weight RMS could instead be kept stable through a suitably designed learning rate schedule. On the other hand, suppose the batch size is large enough that the signal-to-noise ratio $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2$ dominates. Then Equation (16) gives:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \frac{(1-\beta_3^t)^2}{\lambda^2}\tag{20}$$
This may apply to special cases where the model needs to actively increase Weight RMS. However, empirical observations suggest such scenarios are generally rare.
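The mean-field estimate and its special cases can be packaged into a small helper. The function below is hypothetical, not from the original post. It evaluates the AdamW estimate of Equation (16), written out in its docstring, with the signal-to-noise ratio $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2$ passed as `snr`. It assumes constant $\eta,\lambda$, $\lambda>0$ and $t\ge 1$. The usage lines compare it with the three limits discussed above.

```python
import math

def predicted_weight_rms(eta, lam, t, theta0_rms, snr=0.0):
    """Hypothetical helper for the mean-field estimate of ||theta_t||_RMS:

    ||theta_t||^2 ~ beta3^(2t) ||theta_0||^2
                    + (1 - beta3^t)^2 / lam^2
                      * (snr + (1 - beta3)(1 + beta3^t) / (2 (1 - beta3^t))) / (snr + 1),
    with beta3 = 1 - eta * lam and snr = ||mu||^2 / ||sigma||^2.
    """
    beta3 = 1 - eta * lam
    b = beta3**t
    u_bar_sq = (snr + (1 - beta3) * (1 + b) / (2 * (1 - b))) / (snr + 1)
    return math.sqrt(b**2 * theta0_rms**2 + (1 - b) ** 2 / lam**2 * u_bar_sq)

# mu = 0 and t large: approaches sqrt(eta / (2 * lam))
print(predicted_weight_rms(1e-3, 0.1, 100_000, 0.1), math.sqrt(1e-3 / (2 * 0.1)))
# lam -> 0: squared RMS approaches ||theta_0||^2 + eta^2 * t
print(predicted_weight_rms(1e-3, 1e-6, 1000, 0.1) ** 2, 0.1**2 + 1e-3**2 * 1000)
# large signal-to-noise ratio and t large: approaches 1 / lam
print(predicted_weight_rms(1e-3, 0.1, 100_000, 0.1, snr=1e6), 1 / 0.1)
```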
Simulation
We can use the following simulation script to validate the accuracy of the above derivation:
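```python
import numpy as np

N, T = 10000, 100000          # number of weights, number of training steps
beta1, beta2 = 0.9, 0.95      # Adam moment coefficients
eta, lam = 0.001, 0.1         # learning rate and weight decay
m, v = 0, 0                   # first and second moments, zero-initialized
w = np.random.randn(N) * 0.1  # initial weights with RMS ~0.1

for _ in range(T):
    g = np.random.randn(N)    # i.i.d. gradients with mu = 0, sigma = 1
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    # AdamW step without bias correction and with eps = 0 (the t -> inf, eps -> 0 setting)
    w = w - eta * (m / v**0.5 + lam * w)

weight_rms = (w**2).mean()**0.5
# Eq. (17) predicts about sqrt(eta / (2 * lam)) here, since beta3^(2T) ~ e^(-20)
print(weight_rms)
```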
Readers can modify weight initialization or gradient mean/variance to observe how well the final results align with Equation (16). The author has conducted tests, and overall, the agreement is quite reliable.
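For example, the sketch below is a smaller variant of the script with a nonzero gradient mean, compared against Equation (16). It rests on several assumptions. Every component has mean `mu` and standard deviation `sigma`, so $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2=\mu^2/\sigma^2$. A larger learning rate `eta = 1e-2` lets $\beta_3^t$ decay within 5000 steps. The helper names are hypothetical.

```python
import numpy as np

def simulate_adamw_weight_rms(N=2000, T=5000, eta=1e-2, lam=0.1, mu=0.1, sigma=1.0,
                              beta1=0.9, beta2=0.95, seed=0):
    """AdamW (no bias correction, eps = 0) on i.i.d. N(mu, sigma^2) gradients."""
    rng = np.random.default_rng(seed)
    m = np.zeros(N)
    v = np.zeros(N)
    w = rng.standard_normal(N) * 0.1
    for _ in range(T):
        g = mu + sigma * rng.standard_normal(N)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        w = w - eta * (m / np.sqrt(v) + lam * w)
    return np.sqrt(np.mean(w**2))

def eq16_weight_rms(eta, lam, t, theta0_rms, mu, sigma):
    """Eq. (16) with identical per-component mean mu and std sigma."""
    beta3 = 1 - eta * lam
    b = beta3**t
    u_bar_sq = (mu**2 + (1 - beta3) * (1 + b) / (2 * (1 - b)) * sigma**2) / (mu**2 + sigma**2)
    return np.sqrt(b**2 * theta0_rms**2 + (1 - b) ** 2 / lam**2 * u_bar_sq)

sim = simulate_adamw_weight_rms()
pred = eq16_weight_rms(1e-2, 0.1, 5000, 0.1, 0.1, 1.0)
print(sim, pred)  # the two values should be close
```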
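SignSGDM Version
With minor adjustments, the preceding derivation can be adapted to the "SignSGDM + Weight Decay" combination:
$$\text{SignSGDMW}:=\left\{\begin{aligned}
&\boldsymbol{m}_t = \beta_1 \boldsymbol{m}_{t-1} + (1 - \beta_1) \boldsymbol{g}_t\\
&\boldsymbol{u}_t = \operatorname{sign}(\boldsymbol{m}_t)\\
&\boldsymbol{\theta}_t = \boldsymbol{\theta}_{t-1} - \eta_t (\boldsymbol{u}_t + \lambda_t \boldsymbol{\theta}_{t-1})
\end{aligned}\right.\tag{21}$$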
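The only change needed comes from $\operatorname{sign}(\boldsymbol{m}_t)=\boldsymbol{m}_t/\sqrt{\boldsymbol{m}_t^2}$, so we redefine $\bar{\boldsymbol{v}}_t$ as:
$$\bar{\boldsymbol{v}}_t = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t\beta_3^{t-i}\boldsymbol{m}_i^2\tag{22}$$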
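Then:
$$\mathbb{E}[\bar{\boldsymbol{v}}_t] = \frac{1-\beta_3}{1-\beta_3^t}\sum_{i=1}^t\beta_3^{t-i}\mathbb{E}[\boldsymbol{m}_i^2] \approx \boldsymbol{\mu}^2 + \frac{1-\beta_1}{1+\beta_1}\boldsymbol{\sigma}^2\tag{23}$$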
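Here, $\mathbb{E}[\boldsymbol{m}_i^2]$ can be computed by referring to "Why is Adam's Update RMS 0.2?" or "Rethinking Learning Rate and Batch Size (Part 4): EMA". Using the above result, we obtain:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 \approx \beta_3^{2t}\Vert\boldsymbol{\theta}_0\Vert_{RMS}^2 + \frac{(1-\beta_3^t)^2}{\lambda^2}\cdot\frac{\Vert\boldsymbol{\mu}\Vert^2 + \frac{(1-\beta_3)(1+\beta_3^t)}{2(1-\beta_3^t)}\Vert\boldsymbol{\sigma}\Vert^2}{\Vert\boldsymbol{\mu}\Vert^2+\frac{1-\beta_1}{1+\beta_1}\Vert\boldsymbol{\sigma}\Vert^2}\tag{24}$$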
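In particular, in the limit $\boldsymbol{\mu}=\boldsymbol{0}$, $t\to\infty$, we have:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}\cdot\frac{1+\beta_1}{1-\beta_1}}\tag{25}$$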
This result is reasonable because the Update RMS of SignSGDMW is $\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$ times that of AdamW, so for the same $\eta,\lambda$, its Weight RMS is also $\sqrt{\frac{1+\beta_1}{1 - \beta_1}}$ times larger.
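The formula can be evaluated directly to confirm this ratio. The helper below is a hypothetical sketch of Equation (24), with `snr` standing for $\Vert\boldsymbol{\mu}\Vert^2/\Vert\boldsymbol{\sigma}\Vert^2$ and illustrative hyperparameters. It only evaluates the mean-field formula and is not a simulation of SignSGDMW.

```python
import math

def signsgdmw_weight_rms(eta, lam, t, theta0_rms, beta1, snr=0.0):
    """Mean-field estimate for SignSGDMW (hypothetical helper, Eq. (24))."""
    beta3 = 1 - eta * lam
    b = beta3**t
    num = snr + (1 - beta3) * (1 + b) / (2 * (1 - b))
    den = snr + (1 - beta1) / (1 + beta1)  # E[v_bar_t] / ||sigma||^2
    return math.sqrt(b**2 * theta0_rms**2 + (1 - b) ** 2 / lam**2 * num / den)

eta, lam, beta1 = 1e-3, 0.1, 0.9
sign_rms = signsgdmw_weight_rms(eta, lam, 200_000, 0.1, beta1)  # mu = 0, t large
adamw_rms = math.sqrt(eta / (2 * lam))                           # AdamW limit, Eq. (18)
print(sign_rms / adamw_rms, math.sqrt((1 + beta1) / (1 - beta1)))
```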
Related Analysis
As mentioned, result (18) aligns with the paper "Rotational Equilibrium: How Weight Decay Balances Learning Across Neural Networks", but our derivation method is entirely different and yields the more general Equation (16). However, the original paper contains interesting concepts, such as the Total Update Contribution (TUC), which merits discussion.
The TUC idea is as follows. Because of momentum, the current gradient $\boldsymbol{g}_t$ affects not only the current step but also later steps, with a discount. So, letting the number of training steps tend to infinity, we can consider the total contribution of $\boldsymbol{g}_t$ to the whole training process. Specifically, for Adam we have $\boldsymbol{u}_t=\boldsymbol{m}_t/\sqrt{\boldsymbol{v}_t}$, and the contribution of $\boldsymbol{g}_t$ to $\boldsymbol{u}_t$ is $(1-\beta_1)\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_t}$. At the next step, this contribution is discounted (multiplied by $\beta_1$) and the denominator becomes $\sqrt{\boldsymbol{v}_{t+1}}$, and so on. We can therefore define the total contribution as:
$$\tilde{\boldsymbol{u}}_t = \sum_{k=t}^{\infty}(1-\beta_1)\beta_1^{k-t}\frac{\boldsymbol{g}_t}{\sqrt{\boldsymbol{v}_k}}\tag{26}$$
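This decomposes the updates $\boldsymbol{u}_1,\boldsymbol{u}_2,\boldsymbol{u}_3,\cdots$ into $\tilde{\boldsymbol{u}}_1,\tilde{\boldsymbol{u}}_2,\tilde{\boldsymbol{u}}_3,\cdots$. The advantage is that each $\tilde{\boldsymbol{u}}_t$ involves only a single gradient, so we can repeat the derivation from the Quick Estimation section:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS}^2 = \Vert\beta_3\boldsymbol{\theta}_{t-1} - \eta\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2 \approx \beta_3^2\Vert\boldsymbol{\theta}_{t-1}\Vert_{RMS}^2 + \eta^2\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2\tag{27}$$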
The final approximation relies on $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$. We argue that $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t$ is closer to zero than $\boldsymbol{\theta}_{t-1}\cdot\boldsymbol{u}_t$. The reason is that $\tilde{\boldsymbol{u}}_t$ depends only on the current gradient $\boldsymbol{g}_t$, while $\boldsymbol{\theta}_{t-1}$ has not yet been exposed to $\boldsymbol{g}_t$. The two are therefore independent, and when $\boldsymbol{g}_t$ has zero mean, $\boldsymbol{\theta}_{t-1}\cdot\tilde{\boldsymbol{u}}_t\approx 0$ generally holds. To estimate $\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS}^2$, the original paper assumes that the vectors $\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}$ all share the same direction and have unit RMS, which leads to:
$$\Vert\tilde{\boldsymbol{u}}_t\Vert_{RMS} \approx \sum_{k=t}^{\infty}(1-\beta_1)\beta_1^{k-t} = 1\tag{28}$$
Substituting this into Equation (27) and applying the same approximations as in the Quick Estimation section yields:
$$\Vert\boldsymbol{\theta}_t\Vert_{RMS} \approx \sqrt{\frac{\eta}{2\lambda}}\tag{29}$$
However, if one only reads the original paper, many approximations may seem inexplicable. For instance, $\boldsymbol{v}_t$ also contains $\boldsymbol{g}_t$, so claiming $\tilde{\boldsymbol{u}}_t$ only involves the current $\boldsymbol{g}_t$ is not entirely accurate, and the assertion $\Vert\boldsymbol{g}_t/\sqrt{\boldsymbol{v}_k}\Vert_{RMS}=1$ appears somewhat arbitrary. Yet, from the perspective of this article, we see that under mean-field approximation, the operations in the original paper become reasonable; thus, the original paper implicitly employs mean-field methods.
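The bookkeeping behind the total update contribution can be checked on a toy run. Assume $\boldsymbol{m}_0=\boldsymbol{v}_0=\boldsymbol{0}$, no bias correction, $\epsilon=0$, and a finite horizon $T$, so the sum over future steps is truncated at $T$. Under these assumptions the total of the updates $\sum_t\boldsymbol{u}_t$ equals the total of the contributions $\sum_t\tilde{\boldsymbol{u}}_t$ exactly, because both are the same double sum over pairs $t\le k$. The sketch below uses random Gaussian gradients and ignores weight decay.

```python
import numpy as np

rng = np.random.default_rng(0)
T, d, beta1, beta2 = 300, 3, 0.9, 0.95
g = rng.standard_normal((T, d))  # toy i.i.d. Gaussian gradients

# Adam moment recursions (m_0 = v_0 = 0, no bias correction, eps = 0).
m = np.zeros(d)
v = np.zeros(d)
u = np.zeros((T, d))       # u[k] = m_k / sqrt(v_k)
v_hist = np.zeros((T, d))  # v_hist[k] = v_k
for k in range(T):
    m = beta1 * m + (1 - beta1) * g[k]
    v = beta2 * v + (1 - beta2) * g[k] ** 2
    u[k] = m / np.sqrt(v)
    v_hist[k] = v

# Total contribution of g_t, truncated at the horizon:
# u_tilde_t = sum_{k=t}^{T-1} (1 - beta1) * beta1^(k-t) * g_t / sqrt(v_k)
u_tilde = np.zeros((T, d))
for t in range(T):
    coeff = (1 - beta1) * beta1 ** np.arange(T - t)  # discount for k = t..T-1
    u_tilde[t] = g[t] * (coeff[:, None] / np.sqrt(v_hist[t:])).sum(axis=0)

# Both totals are the same double sum over t <= k, so they agree up to round-off.
print(np.max(np.abs(u.sum(axis=0) - u_tilde.sum(axis=0))))
```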
Summary
In this article, we used the mean-field approximation to derive an interesting and perhaps surprising conclusion: the RMS of the weights in a model trained with AdamW can be estimated asymptotically. In the typical case (near-zero gradient mean, training long enough to reach steady state), it depends only on the learning rate and Weight Decay and is approximately $\sqrt{\eta/2\lambda}$.
Original Article: Su Jianlin. Asymptotic Estimation of Weight RMS in AdamW (Part 1). Scientific Spaces.